{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch转ONNX：split\n",
    "\n",
    "{func}`~torch.split` 将张量分割成块。每个块都是原始张量的视图。\n",
    "\n",
    "如果 `split_size_or_sections` 是整数类型，那么张量将被等分成大小相等的块（如果可能的话）。如果张量沿着给定维度 `dim` 的大小不能被 `split_size` 整除，最后一个块将会更小。\n",
    "\n",
    "如果 `split_size_or_sections` 是列表，那么张量将被分割成 `len(split_size_or_sections)` 个块，这些块在 `dim` 维度上的大小根据 `split_size_or_sections` 来确定。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m\n",
      "\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mtensor\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0msplit_size_or_sections\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdim\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m...\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m\n",
      "Splits the tensor into chunks. Each chunk is a view of the original tensor.\n",
      "\n",
      "If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will\n",
      "be split into equally sized chunks (if possible). Last chunk will be smaller if\n",
      "the tensor size along the given dimension :attr:`dim` is not divisible by\n",
      ":attr:`split_size`.\n",
      "\n",
      "If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split\n",
      "into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according\n",
      "to :attr:`split_size_or_sections`.\n",
      "\n",
      "Args:\n",
      "    tensor (Tensor): tensor to split.\n",
      "    split_size_or_sections (int) or (list(int)): size of a single chunk or\n",
      "        list of sizes for each chunk\n",
      "    dim (int): dimension along which to split the tensor.\n",
      "\n",
      "Example::\n",
      "\n",
      "    >>> a = torch.arange(10).reshape(5, 2)\n",
      "    >>> a\n",
      "    tensor([[0, 1],\n",
      "            [2, 3],\n",
      "            [4, 5],\n",
      "            [6, 7],\n",
      "            [8, 9]])\n",
      "    >>> torch.split(a, 2)\n",
      "    (tensor([[0, 1],\n",
      "             [2, 3]]),\n",
      "     tensor([[4, 5],\n",
      "             [6, 7]]),\n",
      "     tensor([[8, 9]]))\n",
      "    >>> torch.split(a, [1, 4])\n",
      "    (tensor([[0, 1]]),\n",
      "     tensor([[2, 3],\n",
      "             [4, 5],\n",
      "             [6, 7],\n",
      "             [8, 9]]))\n",
      "\u001b[0;31mFile:\u001b[0m      /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/functional.py\n",
      "\u001b[0;31mType:\u001b[0m      function"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "torch.split?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/tutorials/frontend\n"
     ]
    }
   ],
   "source": [
    "%cd ../../..\n",
    "import set_env\n",
    "from d2py.utils.file import mkdir\n",
    "temp_dir = \".temp\"\n",
    "mkdir(temp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class Model(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 16, 1, 1, 0, bias=False, groups=1)\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        return torch.split(x, 8, dim=1) #, torch.split(x, [1, 3, 12], dim=1)\n",
    "\n",
    "shape = 1, 3, 8, 8\n",
    "x = torch.rand(*shape)\n",
    "\n",
    "torch_model = Model()\n",
    "# 导出模型\n",
    "output_name = \"split\"\n",
    "torch.onnx.export(\n",
    "    torch_model,               # torch 模型\n",
    "    x,                         # 模型输入或者对于多个输入，使用元组\n",
    "    f\"{temp_dir}/{output_name}.onnx\",             # 模型保存的位置（可以是文件或类似文件的对象）\n",
    "    export_params=True,        # 将训练后的参数权重存储在模型文件内\n",
    "    opset_version=9,          # 导出模型的 ONNX 版本\n",
    "    do_constant_folding=True,  # 是否执行常量折叠以进行优化\n",
    "    input_names = ['data'],    # 模型的输入名称\n",
    "    output_names = ['output'], # 模型的输出名称\n",
    "    # dynamic_axes={'data' : {0 : 'batch_size'},    # 可变长度的轴\n",
    "    #               'output' : {0 : 'batch_size'}}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "import tvm\n",
    "from tvm import relay\n",
    "onnx_model = onnx.load(f\"{temp_dir}/{output_name}.onnx\")\n",
    "mod, params = relay.frontend.from_onnx(onnx_model, {\"data\": shape}, freeze_params=True)\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    mod = relay.quantize.prerequisite_optimize(mod, params)\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    with relay.quantize.qconfig(\n",
    "        skip_conv_layers=[],\n",
    "        # calibrate_mode=\"kl_divergence\", \n",
    "        weight_scale=\"max\",\n",
    "        round_for_shift=True,\n",
    "        # rounding=\"TONEAREST\", # \"UPWARD\" or \"TONEAREST\"\n",
    "        # calibrate_skip_layers=[],\n",
    "        skip_dense_layer=False,\n",
    "    ):\n",
    "        qmod = relay.quantize.quantize(mod, params)\n",
    "qmod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
